import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import json
import bluepysnap
import matplotlib.patches as patches
import numpy as np


def get_spindle_amp(
    sim,
    time_binsize=5,
    n_windows=16,
    first_window_start=3500,
    window_period=1000,
    peak_height=2,
    peak_width=3,
):
    """Estimate a per-window spindle amplitude from the Rt_RC population rate.

    The spike times of the ``Rt_RC`` cells of the ``thalamus_neurons``
    population are binned into ``time_binsize`` ms bins and converted to a mean
    firing rate per spiking cell [Hz]. The rate is then cut into ``n_windows``
    windows whose bin starts lie in
    ``[first_window_start + k * window_period,
    first_window_start + (k + 1) * window_period]`` (both edges inclusive).
    ``scipy.signal.find_peaks`` is run on each window.

    Args:
        sim: bluepysnap ``Simulation`` with a ``thalamus_neurons`` spike report.
        time_binsize: histogram bin width [ms].
        n_windows: number of consecutive analysis windows.
        first_window_start: start of the first window [ms].
        window_period: window width and spacing [ms].
        peak_height: ``height`` threshold passed to ``find_peaks`` [Hz].
        peak_width: ``width`` threshold passed to ``find_peaks`` [bins].

    Returns:
        ``(fs, fs1)``: two lists of length ``n_windows``. ``fs[k]`` is the mean
        rate of all peaks found in window k. ``fs1[k]`` is the rate of its first
        peak. Windows without peaks, or a report with no spikes, give 0.
    """
    from scipy.signal import find_peaks

    report = sim.spikes["thalamus_neurons"].spike_report.filter(group={"mtype": "Rt_RC"}).report
    if report.empty:
        # No spikes at all: there is no rate to bin, so no window has a peak.
        return [0] * n_windows, [0] * n_windows

    times = report.index
    time_start = np.min(times)
    time_stop = np.max(times)
    # arange excludes its stop value, so append it to close the last bin.
    bins = np.append(np.arange(time_start, time_stop, time_binsize), time_stop)
    hist, bin_edges = np.histogram(times, bins=bins)
    # Normalise by the number of distinct spiking nodes -> mean rate per cell.
    node_count = report[["ids", "population"]].drop_duplicates().shape[0]
    freq = 1.0 * hist / node_count / (0.001 * time_binsize)
    bin_starts = bin_edges[:-1]

    fs, fs1 = [], []
    for step in range(n_windows):
        low = first_window_start + step * window_period
        high = low + window_period
        f = freq[(bin_starts >= low) & (bin_starts <= high)]
        peaks = find_peaks(f, height=peak_height, width=peak_width)[0]
        if peaks.size == 0:
            fs.append(0)
            fs1.append(0)
        else:
            fs.append(np.mean(f[peaks]))
            fs1.append(f[peaks[0]])
    return fs, fs1


def make_plot_PSTH_up_down_RT(
    sim, ts, te, UP_state_start_times, DOWN_state_start_times, ax, up_state_duration=500
):
    """Plot the PSTH of the mc2;Rt Rt_RC cells and shade the UP states.

    Args:
        sim: bluepysnap ``Simulation`` with a ``thalamus_neurons`` node population.
        ts: start time [ms] passed to the spike filter.
        te: stop time [ms] passed to the spike filter.
        UP_state_start_times: UP-state onsets [ms]. Each is shaded for
            ``up_state_duration`` ms.
        DOWN_state_start_times: not used by this function. It is kept so
            existing callers keep working.
        ax: matplotlib Axes to draw on.
        up_state_duration: width [ms] of each shaded UP-state span.

    The PSTH is drawn only when the selection contains spikes at two or more
    distinct times. Otherwise only the axes setup and the shading are drawn.
    """
    RT_gids = sim.circuit.nodes["thalamus_neurons"].ids({"region": "mc2;Rt", "mtype": "Rt_RC"})
    spikes_df = sim.spikes.filter(group=RT_gids, t_start=ts, t_stop=te).report
    spike_times = spikes_df.index.to_numpy()

    ylim = 180
    ax.set_ylabel("PSTH [Hz]")
    ax.set_ylim(0, ylim)
    ax.set_xlim(3000, 19000)
    ax.spines.top.set_visible(False)
    ax.spines.right.set_visible(False)

    if spike_times.size > 0:
        time_start = np.min(spike_times)
        time_stop = np.max(spike_times)
        # A zero-length span would give a zero bin size (np.arange step 0).
        if time_stop > time_start:
            # Heuristic for a nice bin size: ~100 spikes per bin on average,
            # capped at 10 ms.
            time_binsize = min(
                10.0, (time_stop - time_start) / ((spike_times.size / 100.0) + 1.0)
            )
            # arange excludes its stop value, so append it to close the last bin.
            bins = np.append(np.arange(time_start, time_stop, time_binsize), time_stop)
            hist, bin_edges = np.histogram(spike_times, bins=bins)

            # Mean rate per spiking cell [Hz].
            node_count = np.unique(spikes_df.ids.to_numpy()).size
            freq = 1.0 * hist / node_count / (0.001 * time_binsize)

            # Plot at the bin centres instead of the bin starts.
            ax.plot(0.5 * (bin_edges[1:] + bin_edges[:-1]), freq, drawstyle="steps-mid")

    # axvspan covers the full axis height regardless of the y-limits.
    for t in UP_state_start_times:
        ax.axvspan(
            t,
            t + up_state_duration,
            linewidth=1,
            edgecolor="none",
            facecolor="grey",
            alpha=0.2,
        )



if __name__ == "__main__":
    # Root folder holding one sub-directory per Iahp variant.
    sim_root = "../../../simulations/06_09_Iahp_variants/"

    # Time window [ms] passed to the spike filter.
    ts, te = 1500, 20000

    # Variants to plot, one subplot row each; uncomment entries to include more.
    sim_type_list = [
        "variant14_lowest-Iahp",
        # "variant15_low-Iahp",
        "variant2_Spp",
        # "variant16_med-Iahp",
        "variant17_high-Iahp",
    ]

    # Protocol suffix of the simulation folder; also used in the output file name.
    sim_folder = "_90percent_CT_up_down"

    # 16 UP states, starting every 1000 ms from 3500 ms to 18500 ms and lasting
    # 500 ms each. These are the 16 windows starting at 3500 ms that
    # get_spindle_amp() scans. The previous hand-written list skipped 13500.
    up_start = list(range(3500, 19000, 1000))
    up_end = [t + 500 for t in up_start]

    plot_size = (15, 2)
    num_rows = len(sim_type_list)
    # squeeze=False keeps `axes` 2-D even for a single variant, so zip() works.
    fig, axes = plt.subplots(
        num_rows, 1, figsize=(plot_size[0], plot_size[1] * num_rows), squeeze=False
    )

    for ax, sim_type_prefix in zip(axes[:, 0], sim_type_list):
        sim_folder_full_name = sim_type_prefix + "_SUPER_LONG" + sim_folder
        sim_path = (
            sim_root + sim_type_prefix + "/" + sim_folder_full_name + "/simulation_config.json"
        )
        with open(sim_path) as f:
            data = json.load(f)

        circuit_var = data["network"].split("/")[-1]
        percent_CT_input = data["inputs"]["spikeReplay_ct_UP1"]["spike_file"].split("_")[-4]

        print(f"\nAnalysing sim for circuit {circuit_var}, percent_CT_input {percent_CT_input}")

        _sim = bluepysnap.Simulation(sim_path)
        make_plot_PSTH_up_down_RT(_sim, ts, te, up_start, up_end, ax)

    # Only the bottom row gets an x label.
    axes[-1, 0].set_xlabel("Time [ms]")
    plt.tight_layout()
    plt.savefig(f"variedIAHP_{sim_folder}.pdf")
